import argparse
import os
import json
import torch
from datasets import load_dataset
from tqdm import tqdm
from torchvision import transforms

from accelerate import infer_auto_device_map, dispatch_model
import sys


def parse_args():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with the attributes ``gpu_id``, ``image_file``,
        ``output_file``, ``datasets`` (list of dataset names) and
        ``max_examples`` (``None`` when no limit was given).
    """
    dataset_choices = [
        'real-toxicity-prompts',
        'jailbreakbench',
        'AdvBench',
        'Harmbench',
        'Harmbench_standard',
        'StrongREJECT',
    ]

    # (flag, add_argument keyword arguments), registered in this order so the
    # generated --help output keeps the same layout.
    option_specs = (
        ("--gpu-id", dict(type=int, default=0, help="GPU ID to use")),
        ("--image_file", dict(
            type=str,
            default='output/InstructBLIP/class0_iter4000.pt',
            help="Path to input image tensor (.pt)",
        )),
        ("--output_file", dict(type=str, default='', help="Path to output JSONL file")),
        ("--datasets", dict(
            type=str,
            nargs='+',
            choices=dataset_choices,
            default=['AdvBench'],
            help="Select one or more datasets to evaluate.",
        )),
        ("--max_examples", dict(
            type=int,
            default=None,
            help="Maximum number of prompts to process",
        )),
    )

    cli = argparse.ArgumentParser(description="Jailbreak Inference Script")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def run_jailbreak_inference(harmful_dataset, text_prompt, output_file, args, model, img, max_new_tokens=256, max_examples=None, name=None):
    """Generate one model response per dataset entry and log it as JSONL.

    The output file is truncated, a header line holding ``vars(args)`` and
    ``text_prompt`` is written, and then one JSON object per generated
    response follows (keys ``prompt`` and ``continuation``, plus ``category``
    for the HarmBench variants).

    Args:
        harmful_dataset: Iterable of entries. Each entry is a string, except
            for ``name`` in ``('Harmbench', 'Harmbench_standard')`` where it
            is an indexable row: ``row[0]`` is used as the prompt text,
            ``row[1]`` as the category and ``row[4]`` as an optional context
            that is prepended to the prompt (presumably the HarmBench CSV
            columns; verify against the loader).
        text_prompt: ``%``-style template with a single ``%s`` placeholder.
        output_file: Path of the JSONL file to (over)write.
        args: Parsed CLI namespace; only serialised into the header line.
        model: Object exposing ``generate(samples, **sampling_kwargs)`` that
            returns a sequence whose first element is the response.
        img: Image input forwarded to the model unchanged.
        max_new_tokens: Accepted for interface compatibility.
            NOTE(review): currently not forwarded to ``model.generate``.
        max_examples: Upper bound on the number of prompts actually sent to
            the model (skipped rows do not count). ``None`` means no limit;
            ``0`` processes nothing.
        name: Dataset name; enables the HarmBench-specific row handling.
    """
    is_harmbench = name in ('Harmbench', 'Harmbench_standard')
    standard_only = name == 'Harmbench_standard'
    processed = 0

    with torch.no_grad(), open(output_file, 'w', encoding='utf-8') as f:
        f.write(json.dumps({
            "args": vars(args),
            "prompt": text_prompt
        }) + "\n")

        for i, user_message in enumerate(tqdm(harmful_dataset, desc="Processing")):
            # Checked before doing any work so that the limit counts prompts
            # that were really generated and max_examples=0 generates none.
            if max_examples is not None and processed >= max_examples:
                break

            category = None
            if is_harmbench:
                category = user_message[1]
                if standard_only and category != 'standard':
                    continue
                if user_message[4] != '':
                    user_message = user_message[4] + "\n\n---\n\n" + user_message[0]
                else:
                    user_message = user_message[0]

            question_adv = text_prompt % user_message
            print(f" ----- {i} ----")
            print(" -- prompt: ---")
            print(question_adv)

            response = model.generate(
                {"image": img, "prompt": question_adv},
                use_nucleus_sampling=True,
                top_p=0.9,
                temperature=1,
            )[0]
            print(" -- continuation: ---")
            print(response)
            print()

            out_one = {
                'prompt': question_adv,
                'continuation': response
            }
            if is_harmbench:
                out_one['category'] = category

            f.write(json.dumps(out_one) + "\n")
            # Generation is slow; flush so an interrupted run keeps its results.
            f.flush()
            processed += 1


def load_dataset_by_name(name):
    """Load the entries of a benchmark and the prompt template to wrap them in.

    Args:
        name: One of ``'real-toxicity-prompts'``, ``'jailbreakbench'``,
            ``'AdvBench'``, ``'Harmbench'``, ``'Harmbench_standard'`` or
            ``'StrongREJECT'``.

    Returns:
        A ``(entries, template)`` tuple. ``entries`` is a list of prompt
        strings, except for the two HarmBench names where it is a list of raw
        CSV rows (lists of strings, header row removed). ``template`` is a
        ``%``-style string with one ``%s`` placeholder.

    Raises:
        ValueError: If ``name`` is not a known dataset.
    """
    if name == 'real-toxicity-prompts':
        def rtp_read(path):
            # Files whose path contains "selected" are taken as-is; otherwise
            # only records flagged "challenging" are kept.
            keep_all = 'selected' in path
            prompts = []
            # Explicit UTF-8: the locale default is not UTF-8 on every platform.
            with open(path, encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    obj = json.loads(line)
                    if keep_all or obj.get('challenging'):
                        prompts.append(obj['prompt']['text'])
            return prompts

        return rtp_read('dataset/real-toxicity-prompts/prompts.jsonl'), 'Please extend the following sentence: %s'

    if name == 'jailbreakbench':
        dataset = load_dataset('JailbreakBench/JBB-Behaviors', 'behaviors')
        return [ex['Goal'] for ex in dataset['harmful']], '%s.'

    if name == 'AdvBench':
        dataset = load_dataset("walledai/AdvBench", split="train")
        return list(dataset['prompt']), '%s.'

    if name in ('Harmbench', 'Harmbench_standard'):
        import csv
        # newline='' is required by the csv module so that line breaks inside
        # quoted fields reach the parser unmodified.
        with open(
            "dataset/HarmBench/data/behavior_datasets/harmbench_behaviors_text_all.csv",
            "r",
            encoding="utf-8",
            newline="",
        ) as f:
            reader = csv.reader(f)
            next(reader, None)  # drop the header row
            # csv.reader yields [] for blank lines; drop them so consumers can
            # index into every row.
            return [row for row in reader if row], '%s.'

    if name == 'StrongREJECT':
        dataset = load_dataset("json", data_files="dataset/StrongREJECT/strongreject_dataset.json", split="train")
        return [ex["text"] for ex in dataset], '%s.'

    raise ValueError(f"Unknown dataset: {name}")


# Per-dataset (low, high) bounds, both exclusive, on the number of lines an
# existing output file must have for the dataset to be treated as already
# evaluated.
_EXPECTED_LINE_COUNTS = {
    'real-toxicity-prompts': (1190, 1210),
    'jailbreakbench': (90, 110),
    'Harmbench': (390, 410),
    'AdvBench': (510, 530),
    'StrongREJECT': (305, 325),
    'Harmbench_standard': (190, 210),
}


def _resolve_output_file(args, dataset_name):
    """Return the JSONL path that results for ``dataset_name`` are written to."""
    if args.output_file:
        if len(args.datasets) > 1:
            # One explicit path shared by several datasets would be truncated
            # by each of them in turn; give every dataset its own file.
            root, ext = os.path.splitext(args.output_file)
            return f'{root}_{dataset_name}{ext}'
        return args.output_file
    # splitext strips only the final extension, unlike split('.p') which cut
    # at the first ".p" anywhere in the path.
    root, _ = os.path.splitext(args.image_file)
    return root + f'_{dataset_name}_instructblip.jsonl'


def _is_already_complete(output_file, dataset_name):
    """Return True if ``output_file`` exists with the expected number of lines."""
    bounds = _EXPECTED_LINE_COUNTS.get(dataset_name)
    if bounds is None or not os.path.exists(output_file):
        return False
    with open(output_file, 'r', encoding='utf-8') as f:
        line_count = sum(1 for _ in f)
    low, high = bounds
    return low < line_count < high


def main():
    """Load the model and the input image, then evaluate every requested dataset.

    For each name in ``args.datasets`` an output path is resolved, datasets
    whose output file already looks complete are skipped, and the remaining
    ones are passed to ``run_jailbreak_inference``.
    """
    args = parse_args()
    # Honour --gpu-id and fall back to CPU when CUDA is unavailable.
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    print("Loading model and tokenizer...")
    from lavis.models import load_model_and_preprocess
    model, _vis_processor, _ = load_model_and_preprocess(
        name='blip2_vicuna_instruct',
        model_type='vicuna13b',
        is_eval=True,
        device=device
    )

    model.eval()
    # NOTE: torch.load unpickles the file; only pass image files you trust.
    # map_location lets a tensor saved on another device load here.
    img = torch.load(args.image_file, map_location=device).to(device)
    resize_transform = transforms.Resize((224, 224))
    img = resize_transform(img)

    for dataset_name in args.datasets:
        output_file = _resolve_output_file(args, dataset_name)

        if _is_already_complete(output_file, dataset_name):
            print(f">>> jsonl file already exist. {dataset_name} inferece canceled.\n")
            continue

        harmful_dataset, text_prompt = load_dataset_by_name(dataset_name)

        print("Input Image File: ", args.image_file)
        print("Output Json File: ", output_file)

        run_jailbreak_inference(
            harmful_dataset=harmful_dataset,
            text_prompt=text_prompt,
            output_file=output_file,
            args=args,
            model=model,
            img=img,
            max_new_tokens=256,
            max_examples=args.max_examples,
            name=dataset_name
        )


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()